{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fairseq in Amazon SageMaker: Pre-trained English to French translation model\n",
    "\n",
    "In this notebook, we will show you how to serve an English to French translation model using pre-trained model provided by the [Fairseq toolkit](https://github.com/pytorch/fairseq)\n",
    "\n",
    "## Permissions\n",
    "\n",
    "Running this notebook requires permissions in addition to the regular SageMakerFullAccess permissions. This is because it creates new repositories in Amazon ECR. The easiest way to add these permissions is simply to add the managed policy AmazonEC2ContainerRegistryFullAccess to the role that you used to start your notebook instance. There's no need to restart your notebook instance when you do this, the new permissions will be available immediately.\n",
    "\n",
    "## Download pre-trained model\n",
    "\n",
    "Fairseq maintains their pre-trained models [here](https://github.com/pytorch/fairseq/blob/master/examples/translation/README.md). We will use the model that was pre-trained on the [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) dataset. As the models are archived in .bz2 format, we need to convert them to .tar.gz as this is the format supported by Amazon SageMaker.\n",
    "\n",
    "### Convert archive"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%sh\n",
    "\n",
    "wget https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2\n",
    "\n",
    "tar xvjf wmt14.v2.en-fr.fconv-py.tar.bz2 > /dev/null\n",
    "cd wmt14.en-fr.fconv-py\n",
    "mv model.pt checkpoint_best.pt\n",
    "\n",
    "tar czvf wmt14.en-fr.fconv-py.tar.gz checkpoint_best.pt dict.en.txt dict.fr.txt bpecodes README.md > /dev/null"
   ]
  },
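  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional sanity check before uploading, the contents of the new archive can be listed with Python's standard `tarfile` module. This is a minimal sketch that assumes the conversion cell above ran in the notebook's working directory; it should list the five files that were passed to `tar`.\n",
    "\n",
    "```python\n",
    "import tarfile\n",
    "\n",
    "# Path produced by the conversion cell above.\n",
    "archive_path = 'wmt14.en-fr.fconv-py/wmt14.en-fr.fconv-py.tar.gz'\n",
    "\n",
    "# Print the name and size in bytes of every member of the archive.\n",
    "with tarfile.open(archive_path, 'r:gz') as archive:\n",
    "    for member in archive.getmembers():\n",
    "        print(member.name, member.size)\n",
    "```"
   ]
  },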
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The pre-trained model has been downloaded and converted. The next step is upload the data to Amazon S3 in order to make it available for running the inference.\n",
    "\n",
    "### Upload data to Amazon S3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "\n",
    "sagemaker_session = sagemaker.Session()\n",
    "region =  sagemaker_session.boto_session.region_name\n",
    "account = sagemaker_session.boto_session.client('sts').get_caller_identity().get('Account')\n",
    "\n",
    "bucket = sagemaker_session.default_bucket()\n",
    "prefix = 'sagemaker/DEMO-pytorch-fairseq/pre-trained-models'\n",
    "\n",
    "role = sagemaker.get_execution_role()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trained_model_location = sagemaker_session.upload_data(\n",
    "    path='wmt14.en-fr.fconv-py/wmt14.en-fr.fconv-py.tar.gz',\n",
    "    bucket=bucket,\n",
    "    key_prefix=prefix)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build Fairseq serving container\n",
    "\n",
    "Next we need to register a Docker image in Amazon SageMaker that will contain the Fairseq code and that will be pulled at inference time to perform the of the precitions from the pre-trained model we downloaded. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%sh\n",
    "chmod +x create_container.sh \n",
    "\n",
    "./create_container.sh pytorch-fairseq-serve"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Fairseq serving image has been pushed into Amazon ECR, the registry from which Amazon SageMaker will be able to pull that image and launch both training and prediction. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hosting the pre-trained model for inference\n",
    "\n",
    "We first needs to define a base JSONPredictor class that will help us with sending predictions to the model once it's hosted on the Amazon SageMaker endpoint. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer\n",
    "\n",
    "class JSONPredictor(RealTimePredictor):\n",
    "    def __init__(self, endpoint_name, sagemaker_session):\n",
    "        super(JSONPredictor, self).__init__(endpoint_name, sagemaker_session, json_serializer, json_deserializer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now use the Model class to deploy the model artificats (the pre-trained model), and deploy it on a CPU instance. Let's use a `ml.m5.xlarge`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker import Model\n",
    "\n",
    "algorithm_name = \"pytorch-fairseq-serve\"\n",
    "image = '{}.dkr.ecr.{}.amazonaws.com/{}:latest'.format(account, region, algorithm_name)\n",
    "\n",
    "model = Model(model_data=trained_model_location,\n",
    "              role=role,\n",
    "              image=image,\n",
    "              predictor_cls=JSONPredictor,\n",
    "             )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "predictor = model.deploy(initial_instance_count=1, instance_type='ml.m5.xlarge')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now it's your time to play. Input a sentence in English and get the translation in French by simply calling predict. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import html\n",
    "\n",
    "result = predictor.predict(\"I love translation\")\n",
    "# Some characters are escaped HTML-style requiring to unescape them before printing\n",
    "print(html.unescape(result))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once you're done with getting predictions, remember to shut down your endpoint as you no longer need it. \n",
    "\n",
    "## Delete endpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.sagemaker_session.delete_endpoint(predictor.endpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Voila! For more information, you can check out the [Fairseq toolkit homepage](https://github.com/pytorch/fairseq). "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_pytorch_p36",
   "language": "python",
   "name": "conda_pytorch_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
